/*********************************************************
* unit:    quantise2         release 0.33                *
* purpose: Advanced but slower palette quantiser.        *
* license:     GPL or LGPL                               *
* Copyright: (c) 2024-2025 Jaroslav Fojtik               *
**********************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Atomic helpers for the OpenMP build. This whole region is compiled only
// when _OPENMP is defined; the serial build uses plain arithmetic instead.
#if defined _OPENMP
 #ifdef _MSC_VER
  // MSVC: the Interlocked* family is taken from the Windows headers.
  #include <Windows.h>
  #if defined(_X86_)
   // Only a prototype is supplied for _X86_ builds; no definition is visible
   // in this file. NOTE(review): presumably the 32-bit headers do not declare
   // InterlockedAdd64 and the symbol is resolved elsewhere -- confirm.
   extern "C"
   {
    void InterlockedAdd64(long long *Value, unsigned ValueAdd);
   }
  #endif
 #else
   // Other compilers: provide same-named functions built on the
   // __sync_fetch_and_add builtin. The fetched value is discarded, both
   // functions return nothing.
   extern "C"
   {
    void InterlockedAdd64(long long *Value, unsigned ValueAdd);
    void InterlockedIncrement(long *Value);
   }
   /// Atomically adds ValueAdd to *Value.
   void InterlockedAdd64(long long *Value, unsigned ValueAdd)
   {
     __sync_fetch_and_add(Value,ValueAdd);
   }
   /// Atomically adds 1 to *Value.
   /// NOTE(review): the operand is a 'long', which is 64 bits wide on LP64
   /// targets; callers must only pass the address of an object that really
   /// has that width -- confirm for every call site.
   void InterlockedIncrement(long *Value)
   {
    __sync_fetch_and_add(Value,1);
   }
 #endif
#endif

#include "raster.h"
#include "ras_prot.h"

#include "quantise.h"

#include "OmpFix.h"


/// Square of a 32-bit signed integer; the product is formed in 32-bit
/// arithmetic and returned unchanged.
inline int32_t sqr(const int32_t val)
{
  const int32_t Squared = val * val;
  return Squared;
}


/// Square of a 64-bit signed integer, evaluated in double precision so the
/// product cannot wrap around an integer type.
inline double sqr(const int64_t val)
{
  const double AsDouble = static_cast<double>(val);
  return AsDouble * AsDouble;
}

/// Square of a double precision value.
inline double sqr(const double val)
{
  const double Squared = val * val;
  return Squared;
}


//#define _DEBUG

///////////////////////////////////////////////////

class OctElement2;

/// Best merge candidate found so far while walking the octree.
struct TMergeCriteria
{
  uint64_t Criteria;	///< Cost of the candidate; a new pair replaces it when its cost is not larger.
  OctElement2 *OE;	///< Node whose children form the candidate pair, NULL while nothing was found.
  uint8_t i;		///< Index of the first child of the pair inside OE.
  uint8_t j;		///< Index of the second child of the pair inside OE.
};


/// One node of the colour octree used by the quantiser.
/// Every node accumulates per-channel sums and sums of squares of all pixels
/// that were added to it, and optionally owns a table of child nodes that
/// partition its [Min,Max] colour box.
class OctElement2
{
public:
  OctElement2();
  ~OctElement2();
  /// Builds 'Levels' generations of eight children below a childless node.
  void Expand(uint8_t Levels);
  /// Removes children that never received a pixel.
  void Cleanup(void);
  /// Accounts one pixel in this node and in the matching child.
  void Add(const RGBQuad &RGB);
  /// Number of childless nodes in this subtree.
  unsigned Leaves(void) const;
  /// Merges child j into child i.
  void Collapse(const uint8_t i, const uint8_t j);
  /// Searches the subtree for the cheapest pair of sibling leaves.
  bool FindMerge(TMergeCriteria &MC);
  /// Writes the average colour of every leaf to Palette, starting at idx.
  void FeedPalette(APalette *Palette, unsigned &idx) const;
  void FeedPalette(APalette *Palette) const {unsigned idx=0;FeedPalette(Palette,idx);};
  /// Per-channel average of the accumulated pixels; 0 for a node without hits
  /// (the unguarded division would fault on an empty node).
  unsigned AvgR(void) const {return (Hits>0) ? (unsigned)(SumR/Hits) : 0;};
  unsigned AvgG(void) const {return (Hits>0) ? (unsigned)(SumG/Hits) : 0;};
  unsigned AvgB(void) const {return (Hits>0) ? (unsigned)(SumB/Hits) : 0;};
  /// Per-channel sum of squared deviations from the mean:
  /// Sum2 - Sum*Sum/Hits, truncated to 'unsigned'. 0 for a node without hits.
  unsigned Delta2R(void) const {return (Hits>0) ? (unsigned)(SumR2 - SqrSumPerHit(SumR,Hits)) : 0;};
  unsigned Delta2G(void) const {return (Hits>0) ? (unsigned)(SumG2 - SqrSumPerHit(SumG,Hits)) : 0;};
  unsigned Delta2B(void) const {return (Hits>0) ? (unsigned)(SumB2 - SqrSumPerHit(SumB,Hits)) : 0;};
  /// Same measure summed over all three channels.
  unsigned Delta2(void) const;
#ifdef _DEBUG
  void Print(void);
#endif
  /// NOTE(review): declared only; no definition is visible in this file.
  void Remove(const uint8_t j);

  int64_t SumR, SumG, SumB;		///< Sums of channel values of all added pixels.
  int64_t SumR2, SumG2, SumB2;		///< Sums of squared channel values.
  int32_t Hits;				///< Number of added pixels.
  uint16_t MinR, MaxR;			///< Inclusive red range covered by this node.
  uint16_t MinG, MaxG;			///< Inclusive green range covered by this node.
  uint16_t MinB, MaxB;			///< Inclusive blue range covered by this node.
  uint8_t DownCount;			///< Number of slots in Down.
  uint8_t Flags;			///< Not referenced anywhere in this file.
  OctElement2 **Down;			///< malloc'ed table of owned children, NULL for a leaf.

private:
  /// floor(Sum*Sum / Count) for Sum>=0 and Count>0 without forming Sum*Sum.
  /// With Sum = q*Count + r:  Sum*Sum = Count*q*(Sum+r) + r*r, so the quotient
  /// is q*(Sum+r) + r*r/Count and every intermediate stays far below 2^63 for
  /// the sums this class accumulates.
  static int64_t SqrSumPerHit(const int64_t Sum, const int32_t Count)
  {
    if(Count<=0) return 0;
    const int64_t q = Sum / Count;
    const int64_t r = Sum % Count;
    return q*(Sum+r) + (r*r)/Count;
  }

  // The node owns its children through raw pointers; a member-wise copy would
  // make two nodes delete the same subtree. Copying is therefore disabled
  // (declared, intentionally never defined).
  OctElement2(const OctElement2 &);
  OctElement2 &operator=(const OctElement2 &);
};


/// Creates an empty leaf: no children, no hits, all statistics zero.
/// The colour box and Flags are zeroed as well, so no member is ever read
/// while indeterminate; callers assign the real Min/Max range afterwards.
OctElement2::OctElement2()
  : SumR(0), SumG(0), SumB(0),
    SumR2(0), SumG2(0), SumB2(0),
    Hits(0),
    MinR(0), MaxR(0),
    MinG(0), MaxG(0),
    MinB(0), MaxB(0),
    DownCount(0),
    Flags(0),
    Down(NULL)
{
}


/// Destroys all owned children (last slot first) and releases the child table.
OctElement2::~OctElement2()
{
  if(Down==NULL) return;

  for( ; DownCount>0; )
  {
    OctElement2 *&Slot = Down[--DownCount];
    if(Slot==NULL) continue;
    delete Slot;
    Slot = NULL;
  }

  free(Down);
  Down = NULL;
}


unsigned OctElement2::Delta2(void) const
{
  return (unsigned)(SumR2+SumG2+SumB2 - (uint64_t)(SumR*SumR + SumG*SumG + SumB*SumB)/Hits);
}



void OctElement2::Expand(uint8_t Levels)
{
  if(Levels==0) return;
  if(Down==NULL)
  {
    DownCount = 0;
    Down = (OctElement2**)malloc(8*sizeof(OctElement2*));
    if(Down)
      do
      {
        OctElement2 *const pOct = new OctElement2;
        if(0 == (DownCount & 1))
        {
          pOct->MinR = MinR;
          pOct->MaxR = MinR + (MaxR-MinR)/2;
        }
        else
        {
          pOct->MinR = 1 + MinR + (MaxR-MinR)/2;
          pOct->MaxR = MaxR;
        }
        if(0 == (DownCount & 2))
        {
          pOct->MinG = MinG;
          pOct->MaxG = MinG + (MaxG-MinG)/2;
        }
        else
        {
          pOct->MinG = 1 + MinG + (MaxG-MinG)/2;
          pOct->MaxG = MaxG;
        }

        if(0 == (DownCount & 4))
        {
          pOct->MinB = MinB;
          pOct->MaxB = MinB + (MaxB-MinB)/2;
        }        
        else
        {
          pOct->MinB = 1 + MinB + (MaxB-MinB)/2;
          pOct->MaxB = MaxB;
        }
        
        pOct->Expand(Levels-1);
        Down[DownCount] = pOct;        
      } while(++DownCount<8);
  }
}


#ifdef _DEBUG
/// Debug aid: dumps the colour box, average colour and hit count of this node
/// to stdout, then does the same for every child.
void OctElement2::Print(void)
{
  const double R = AvgR();
  const double G = AvgG();
  const double B = AvgB();

  // Extended variant that also reports the deviations, kept for reference:
  //printf("\n[<%u,%u>; <%u,%u>; <%u,%u>] [%.2f %.2f %.2f] n=%u Delta=%.1f[%.1f %.1f %.1f]", MinR, MaxR, MinG, MaxG, MinB, MaxB,
  //		R,G,B, Hits, Delta2(), Delta2R(), Delta2G(), Delta2B());
  printf("\n[<%u,%u>; <%u,%u>; <%u,%u>] [%.2f %.2f %.2f] n=%u",
         MinR, MaxR, MinG, MaxG, MinB, MaxB,
         R, G, B, Hits);

  uint8_t k = 0;
  while(k<DownCount)
  {
    OctElement2 *const Child = Down[k];
    if(Child!=NULL) Child->Print();
    k++;
  }
}
#endif


/// Remove all octnodes with zero occurance.
void OctElement2::Cleanup(void)
{
uint8_t Remove=0;
  if(DownCount==0 || Down==NULL) return;
  for(uint8_t i=0; i<DownCount; i++)
  {
    if(Down[i]==NULL)
    {
      Remove++;
      continue;
    }
    if(Down[i]->Hits==0)
    {
      delete Down[i];
      Down[i] = NULL;
      Remove++;
      continue;
    }
    Down[i]->Cleanup();
  }

  if(Remove>0)
  {
    if(Remove>=DownCount)	// The subtree is empty.
    {
      free(Down);
      Down = NULL;
      DownCount = 0;
      return;
    }

    if(DownCount-Remove==1)	// Only one element left, subtree collapsed.
    {
      for(uint8_t i=0; i<DownCount; i++)
        {if(Down[i]!=NULL) delete Down[i];}
      free(Down);
      Down = NULL;
      DownCount = 0;
      return;
    }

    OctElement2 **Down2 = (OctElement2**)malloc(sizeof(OctElement2*)*(DownCount-Remove));
    if(Down2==NULL) return;		// memory allocation failure.

    Remove = 0;
    for(uint8_t i=0; i<DownCount; i++)
    {
      if(Down[i]==NULL) continue;
      Down2[Remove++] = Down[i];
    }
    free(Down);
    Down = Down2;
    DownCount = Remove;
  }
}


unsigned OctElement2::Leaves(void) const
{
  if(DownCount==0 || Down==NULL) return 1;
  unsigned leaves = 0;
  for(uint8_t i=0; i<DownCount; i++)
  {
    if(Down[i]!=NULL)
      leaves += Down[i]->Leaves();
  }
  return leaves;
}


/// Searches this subtree for the pair of sibling leaves that is cheapest to merge.
/// The cost of a pair is the sum of squared deviations the two leaves would
/// have as one cluster. Whenever a pair costs no more than MC.Criteria, MC is
/// overwritten with that pair (cost, owning node, both child indices).
/// @return false when this node has at most one child, i.e. the caller may
///         treat it as a leaf; true otherwise.
bool OctElement2::FindMerge(TMergeCriteria &MC)
{
  if(Down==NULL || DownCount<=1) return false;

  // Recurse first and remember which children reported themselves as leaves.
  uint8_t LeafMask = 0;
  for(uint8_t k=0; k<DownCount; k++)
  {
    OctElement2 *const Child = Down[k];
    if(Child!=NULL && !Child->FindMerge(MC))
      LeafMask |= 1<<k;
  }
  if(LeafMask==0) return true;

  for(uint8_t a=0; a<DownCount-1; a++)
  {
    if(!(LeafMask & (1<<a))) continue;
    const OctElement2 &A = *Down[a];

    for(uint8_t b=a+1; b<DownCount; b++)
    {
      if(!(LeafMask & (1<<b))) continue;
      const OctElement2 &B = *Down[b];

      const int64_t SumOfSquares = (A.SumR2+B.SumR2) + (A.SumG2+B.SumG2) + (A.SumB2+B.SumB2);
      const double SquaredSums = sqr(A.SumR+B.SumR) + sqr(A.SumG+B.SumG) + sqr(A.SumB+B.SumB);
      const uint64_t Cost = SumOfSquares - SquaredSums / (A.Hits+B.Hits);

      if(Cost > MC.Criteria) continue;
      MC.Criteria = Cost;
      MC.OE = this;
      MC.i = a;
      MC.j = b;
    }
  }
  return true;
}


void OctElement2::Collapse(const uint8_t i, const uint8_t j)
{
  if(DownCount==0 || Down==NULL) return;
  if(i>=DownCount || j>=DownCount) return;
  if(Down[i]==NULL || Down[j]==NULL) return;

  if(DownCount <= 2)
  {
    while(DownCount > 0)
    {
      DownCount--;
      if(Down[DownCount]!=NULL)
        delete Down[DownCount];
    }
    free(Down);
    Down = NULL;
    return;
  }

  Down[i]->Hits += Down[j]->Hits;
  Down[i]->SumR += Down[j]->SumR;
  Down[i]->SumR2 += Down[j]->SumR2;
  Down[i]->SumG += Down[j]->SumG;
  Down[i]->SumG2 += Down[j]->SumG2;
  Down[i]->SumB += Down[j]->SumB;
  Down[i]->SumB2 += Down[j]->SumB2;

  if(Down[j]->MinR < Down[i]->MinR) Down[i]->MinR=Down[j]->MinR;
  if(Down[j]->MinG < Down[i]->MinG) Down[i]->MinG=Down[j]->MinG;
  if(Down[j]->MinB < Down[i]->MinB) Down[i]->MinB=Down[j]->MinB;
  if(Down[j]->MaxR > Down[i]->MaxR) Down[i]->MaxR=Down[j]->MaxR;
  if(Down[j]->MaxG > Down[i]->MaxG) Down[i]->MaxG=Down[j]->MaxG;
  if(Down[j]->MaxB > Down[i]->MaxB) Down[i]->MaxB=Down[j]->MaxB;

  delete Down[j];
  Down[j] = NULL;

  OctElement2 **Down2 = (OctElement2**)malloc(sizeof(OctElement2*)*(DownCount-1));
  if(Down2==NULL) return;		// memory allocation failure.
  if(j>0) memcpy(Down2,Down, j*sizeof(OctElement2*));
  if(DownCount-j-1>0) memcpy(Down2+j,Down+j+1, sizeof(OctElement2*)*(DownCount-j-1));
  free(Down);
  Down = Down2;
  DownCount--;
}


/// Accounts one pixel in this node and hands it down to the child whose
/// colour box contains it.
/// In the OpenMP build the statistics are updated atomically, because
/// several threads add pixels to the same tree concurrently.
void OctElement2::Add(const RGBQuad &RGB)
{
#if defined _OPENMP
  InterlockedAdd64(&SumR,RGB.R);
  InterlockedAdd64(&SumG,RGB.G);
  InterlockedAdd64(&SumB,RGB.B);
  InterlockedAdd64(&SumR2, RGB.R*(uint32_t)RGB.R);
  InterlockedAdd64(&SumG2, RGB.G*(uint32_t)RGB.G);
  InterlockedAdd64(&SumB2, RGB.B*(uint32_t)RGB.B);
 #ifdef _MSC_VER
  InterlockedIncrement((long*)&Hits);	// 'long' is 32 bits wide with MSVC, the same width as Hits.
 #else
  // Increment the int32_t itself. Going through a 'long*' would perform a
  // 64-bit read-modify-write on LP64 targets and touch the members stored
  // behind Hits.
  __sync_fetch_and_add(&Hits,1);
 #endif
#else
  SumR += RGB.R;
  SumG += RGB.G;
  SumB += RGB.B;
  SumR2 += RGB.R*(uint32_t)RGB.R;
  SumG2 += RGB.G*(uint32_t)RGB.G;
  SumB2 += RGB.B*(uint32_t)RGB.B;
  Hits++;
#endif

  switch(DownCount)
  {
    case 0: return;
    case 8:		// Full rank, use speedup.
        {
          // The child index is derived directly from the half of the colour
          // box the pixel falls into (bit 0 red, bit 1 green, bit 2 blue).
          uint8_t i = 0;
          if(RGB.R > MinR+(MaxR-MinR)/2) i|=1;
          if(RGB.G > MinG+(MaxG-MinG)/2) i|=2;
          if(RGB.B > MinB+(MaxB-MinB)/2) i|=4;
          if(Down[i]!=NULL) Down[i]->Add(RGB);	// Same NULL tolerance as the search below.
        }
        return;
  }

  // Incomplete child table: look for children whose colour box contains the pixel.
  for(uint8_t i=0; i<DownCount; i++)
  {
    OctElement2 * const pDown = Down[i];
    if(pDown!=NULL) 
    {
      if(RGB.R < pDown->MinR) continue;
      if(RGB.G < pDown->MinG) continue;
      if(RGB.B < pDown->MinB) continue;
      if(RGB.R > pDown->MaxR) continue;
      if(RGB.G > pDown->MaxG) continue;
      if(RGB.B > pDown->MaxB) continue;
      pDown->Add(RGB);
    }
  }
}


/// Stores the average colour of every leaf of this subtree into Palette.
/// @param Palette  palette that receives the colours.
/// @param idx      next palette slot to write; advanced by one for each leaf.
void OctElement2::FeedPalette(APalette *Palette, unsigned &idx) const
{
  const bool HasChildren = (Down!=NULL) && (DownCount!=0);
  if(HasChildren)
  {
    for(unsigned k=0; k<DownCount; ++k)
    {
      const OctElement2 *const Child = Down[k];
      if(Child==NULL) continue;
      Child->FeedPalette(Palette,idx);
    }
    return;
  }

  // Leaf: one palette entry holding the mean colour of its pixels.
  Palette->setR(idx,AvgR());
  Palette->setG(idx,AvgG());
  Palette->setB(idx,AvgB());
  ++idx;
}


///////////////////////////////////////////////////


/// Builds a palette of at most 'indices' colours for an RGB(A) raster.
/// All pixels are accumulated in a five level colour octree; sibling leaves
/// are then merged, cheapest pair first, until no more than 'indices' leaves
/// remain, and the mean colour of every leaf becomes one palette entry.
/// @param RasterRGB  source raster, read row by row.
/// @param indices    requested number of palette entries, 2..256.
/// @return newly built palette, or NULL for invalid arguments, when no pixel
///         could be read, or when BuildPalette fails.
APalette *FindIndices2(Raster2DAbstract *RasterRGB, unsigned indices)
{
  if(RasterRGB==NULL || RasterRGB->Size1D<=0 || RasterRGB->Size2D<=0 || indices<=1) return NULL;

  // Palettes that need more than 8 bits per index are not supported.
  // Reject them before the image is scanned instead of afterwards.
  if(indices>256) return NULL;

  OctElement2 OctRoot;
  OctRoot.MaxR = OctRoot.MaxG = OctRoot.MaxB = 255;
  OctRoot.MinR = OctRoot.MinG = OctRoot.MinB = 0;
  OctRoot.Expand(5);

#pragma omp parallel
  {
    // Every thread gets its own row accessor that is pointed at the rows of RasterRGB.
    // NOTE(review): Shadow presumably tells the accessor that it does not own Data1D -- confirm.
    Raster1DAbstract * const RasRGBrow 
	  =  (RasterRGB->Channels()==3) ? CreateRaster1DRGB(0,RasterRGB->GetPlanes()/3) : CreateRaster1DRGBA(0,RasterRGB->GetPlanes()/4);
    if(RasRGBrow!=NULL)
    {
      RasRGBrow->Shadow = true;
      RasRGBrow->Size1D = RasterRGB->Size1D;
    }
#pragma omp for ordered schedule(dynamic)
    for(UNS_OMP_IT y=0; y<RasterRGB->Size2D; y++)
    {
      // The work-sharing loop must be entered by every thread, so a missing
      // accessor is skipped per row instead of leaving the region early.
      if(RasRGBrow==NULL) continue;
      RasRGBrow->Data1D = RasterRGB->GetRow(y);
      if(RasRGBrow->Data1D==NULL) continue;    
      for(unsigned x=0; x<RasterRGB->Size1D; x++)
      {
        RGBQuad RGB;
        RasRGBrow->Get(x,&RGB);
        OctRoot.Add(RGB);
      }
    }

    delete RasRGBrow;
  }

  // Nothing was accumulated (no accessor or no readable row); there is no
  // colour to average and FeedPalette would divide by a zero hit count.
  if(OctRoot.Hits<=0) return NULL;

  OctRoot.Cleanup();

#ifdef _DEBUG
  //OctRoot.Print();
#endif

  unsigned LeafCount = OctRoot.Leaves();

  while(LeafCount>indices)
  {
   TMergeCriteria MC;
   MC.OE = NULL;
   MC.i = MC.j = 0;

   // Accept any candidate. A bound derived from the 'unsigned' Delta2() can
   // be truncated below every real criterion and end the merging too early.
   MC.Criteria = ~(uint64_t)0;
   OctRoot.FindMerge(MC);
   if(MC.OE==NULL)
	break;		// Algorithm failed.
   MC.OE->Collapse(MC.i,MC.j);   

   LeafCount--;
// debug only
//   unsigned i=OctRoot.Leaves();
//   if(LeafCount != i)
//     printf("mismatch %u:%u",i,LeafCount);
  }
  // LeafCount = OctRoot.Leaves();

  // NOTE(review): when the loop above gave up, more than 'indices' leaves are
  // fed into the palette; verify that APalette tolerates indices past its size.
  APalette *Palette = BuildPalette(indices,8);
  if(Palette)
      OctRoot.FeedPalette(Palette);
  return Palette;
}
